[CUDA] Reduce the number of stream-k blocks to reduce the overhead of the flash_attn_stream_k_fixup kernel #21086

Closed

gaugarg-nv wants to merge 3 commits into ggml-org:master from gaugarg-nv:reduce_stream_k_block

Conversation

@gaugarg-nv (Contributor) commented Mar 27, 2026

For GPUs with high SM counts, the number of stream-k blocks needed to fill the entire GPU can be very high. In such cases, the flash_attn_stream_k_fixup kernel takes a significant amount of time.

The fix is to reduce the number of stream-k blocks: if max_blocks_per_sm is 2 or 4, it is reduced by a factor of 2. This can lower occupancy, but I am seeing net gains with this change.

Future work: Explore how to optimize flash_attn_stream_k_fixup for a large number of blocks.

Performance

All rows: NVIDIA GeForce RTX 5090.

| model_type | n_ubatch | n_prompt | n_depth | Master avg_ts | PR avg_ts | Speed-up |
|---|---:|---:|---:|---:|---:|---:|
| qwen3moe 30B.A3B Q4_0 | 1 | 512 | 8192 | 211.0296 | 212.206 | 1.01 |
| qwen3moe 30B.A3B Q4_0 | 2 | 512 | 8192 | 256.9456 | 298.658 | 1.16 |
| qwen3moe 30B.A3B Q4_0 | 4 | 512 | 8192 | 418.3045 | 469.4834 | 1.12 |
| qwen3moe 30B.A3B Q4_0 | 8 | 512 | 8192 | 541.8795 | 580.987 | 1.07 |
| gpt-oss 20B MXFP4 MoE | 1 | 512 | 8192 | 361.2679 | 385.7475 | 1.07 |
| gpt-oss 20B MXFP4 MoE | 2 | 512 | 8192 | 487.9653 | 512.7976 | 1.05 |
| gpt-oss 20B MXFP4 MoE | 4 | 512 | 8192 | 725.512 | 746.1654 | 1.03 |
| gpt-oss 20B MXFP4 MoE | 8 | 512 | 8192 | 971.5393 | 989.0909 | 1.02 |
| qwen3 4B Q4_K - Medium | 1 | 512 | 8192 | 272.3259 | 272.5065 | 1.00 |
| qwen3 4B Q4_K - Medium | 2 | 512 | 8192 | 458.9721 | 459.2129 | 1.00 |
| qwen3 4B Q4_K - Medium | 4 | 512 | 8192 | 669.9135 | 721.8211 | 1.08 |
| qwen3 4B Q4_K - Medium | 8 | 512 | 8192 | 1088.661 | 1158.717 | 1.06 |
| llama 8B Q4_0 | 1 | 512 | 8192 | 233.3957 | 233.4368 | 1.00 |
| llama 8B Q4_0 | 2 | 512 | 8192 | 417.9664 | 418.6286 | 1.00 |
| llama 8B Q4_0 | 4 | 512 | 8192 | 714.1412 | 768.6836 | 1.08 |
| llama 8B Q4_0 | 8 | 512 | 8192 | 1091.562 | 1156.124 | 1.06 |
| qwen3 14B Q8_0 | 1 | 512 | 8192 | 84.61194 | 84.67307 | 1.00 |
| qwen3 14B Q8_0 | 2 | 512 | 8192 | 149.7183 | 156.259 | 1.04 |
| qwen3 14B Q8_0 | 4 | 512 | 8192 | 292.13 | 303.9734 | 1.04 |
| qwen3 14B Q8_0 | 8 | 512 | 8192 | 519.2778 | 539.2146 | 1.04 |
| qwen2 32B Q4_K - Medium | 1 | 512 | 8192 | 63.07156 | 63.29621 | 1.00 |
| qwen2 32B Q4_K - Medium | 2 | 512 | 8192 | 107.9662 | 114.3425 | 1.06 |
| qwen2 32B Q4_K - Medium | 4 | 512 | 8192 | 167.2002 | 178.8155 | 1.07 |
| qwen2 32B Q4_K - Medium | 8 | 512 | 8192 | 224.0947 | 240.8187 | 1.07 |
| qwen3moe 30B.A3B Q4_0 | 1 | 512 | 16384 | 191.0213 | 189.605 | 0.99 |
| qwen3moe 30B.A3B Q4_0 | 2 | 512 | 16384 | 241.0453 | 275.4833 | 1.14 |
| qwen3moe 30B.A3B Q4_0 | 4 | 512 | 16384 | 395.3927 | 439.7267 | 1.11 |
| qwen3moe 30B.A3B Q4_0 | 8 | 512 | 16384 | 522.4441 | 554.3349 | 1.06 |
| gpt-oss 20B MXFP4 MoE | 1 | 512 | 16384 | 346.9191 | 358.5496 | 1.03 |
| gpt-oss 20B MXFP4 MoE | 2 | 512 | 16384 | 472.5971 | 495.9779 | 1.05 |
| gpt-oss 20B MXFP4 MoE | 4 | 512 | 16384 | 709.3301 | 727.3115 | 1.03 |
| gpt-oss 20B MXFP4 MoE | 8 | 512 | 16384 | 957.0377 | 973.5629 | 1.02 |
| qwen3 4B Q4_K - Medium | 1 | 512 | 16384 | 224.2367 | 223.4497 | 1.00 |
| qwen3 4B Q4_K - Medium | 2 | 512 | 16384 | 389.1294 | 388.2897 | 1.00 |
| qwen3 4B Q4_K - Medium | 4 | 512 | 16384 | 588.4252 | 632.4903 | 1.07 |
| qwen3 4B Q4_K - Medium | 8 | 512 | 16384 | 977.6103 | 1033.039 | 1.06 |
| llama 8B Q4_0 | 1 | 512 | 16384 | 201.3511 | 200.9778 | 1.00 |
| llama 8B Q4_0 | 2 | 512 | 16384 | 365.8022 | 364.9601 | 1.00 |
| llama 8B Q4_0 | 4 | 512 | 16384 | 633.861 | 680.0114 | 1.07 |
| llama 8B Q4_0 | 8 | 512 | 16384 | 992.5401 | 1042.028 | 1.05 |
| qwen3 14B Q8_0 | 1 | 512 | 16384 | 78.82764 | 78.77085 | 1.00 |
| qwen3 14B Q8_0 | 2 | 512 | 16384 | 140.3304 | 146.4556 | 1.04 |
| qwen3 14B Q8_0 | 4 | 512 | 16384 | 274.16 | 284.4379 | 1.04 |
| qwen3 14B Q8_0 | 8 | 512 | 16384 | 487.1352 | 499.993 | 1.03 |
| qwen2 32B Q4_K - Medium | 1 | 512 | 16384 | 58.04407 | 57.97893 | 1.00 |
| qwen2 32B Q4_K - Medium | 2 | 512 | 16384 | 100.3653 | 105.5507 | 1.05 |
| qwen2 32B Q4_K - Medium | 4 | 512 | 16384 | 158.2177 | 160.0042 | 1.01 |
| qwen2 32B Q4_K - Medium | 8 | 512 | 16384 | 216.1302 | 212.9512 | 0.99 |
| qwen3moe 30B.A3B Q4_0 | 1 | 512 | 32768 | 158.2302 | 157.8661 | 1.00 |
| qwen3moe 30B.A3B Q4_0 | 2 | 512 | 32768 | 213.2867 | 240.3428 | 1.13 |
| qwen3moe 30B.A3B Q4_0 | 4 | 512 | 32768 | 356.535 | 392.2516 | 1.10 |
| qwen3moe 30B.A3B Q4_0 | 8 | 512 | 32768 | 482.2693 | 507.9359 | 1.05 |
| gpt-oss 20B MXFP4 MoE | 1 | 512 | 32768 | 318.1751 | 336.3647 | 1.06 |
| gpt-oss 20B MXFP4 MoE | 2 | 512 | 32768 | 445.1713 | 465.8869 | 1.05 |
| gpt-oss 20B MXFP4 MoE | 4 | 512 | 32768 | 673.0355 | 689.9298 | 1.03 |
| gpt-oss 20B MXFP4 MoE | 8 | 512 | 32768 | 912.4386 | 930.0826 | 1.02 |
| qwen3 4B Q4_K - Medium | 1 | 512 | 32768 | 165.9443 | 165.3803 | 1.00 |
| qwen3 4B Q4_K - Medium | 2 | 512 | 32768 | 302.5722 | 302.1907 | 1.00 |
| qwen3 4B Q4_K - Medium | 4 | 512 | 32768 | 478.7388 | 508.992 | 1.06 |
| qwen3 4B Q4_K - Medium | 8 | 512 | 32768 | 813.9919 | 860.0112 | 1.06 |
| llama 8B Q4_0 | 1 | 512 | 32768 | 156.9025 | 156.4821 | 1.00 |
| llama 8B Q4_0 | 2 | 512 | 32768 | 294.843 | 294.3899 | 1.00 |
| llama 8B Q4_0 | 4 | 512 | 32768 | 522.1251 | 553.6983 | 1.06 |
| llama 8B Q4_0 | 8 | 512 | 32768 | 841.0377 | 882.3715 | 1.05 |
| qwen3 14B Q8_0 | 1 | 512 | 32768 | 70.05649 | 70.02391 | 1.00 |
| qwen3 14B Q8_0 | 2 | 512 | 32768 | 125.872 | 130.9071 | 1.04 |
| qwen3 14B Q8_0 | 4 | 512 | 32768 | 245.0809 | 254.5865 | 1.04 |
| qwen3 14B Q8_0 | 8 | 512 | 32768 | 435.5816 | 453.5067 | 1.04 |
| qwen2 32B Q4_K - Medium | 1 | 512 | 32768 | 50.51349 | 50.50699 | 1.00 |
| qwen2 32B Q4_K - Medium | 2 | 512 | 32768 | 88.64097 | 92.88735 | 1.05 |
| qwen2 32B Q4_K - Medium | 4 | 512 | 32768 | 141.5971 | 145.0327 | 1.02 |
| qwen2 32B Q4_K - Medium | 8 | 512 | 32768 | 198.585 | 202.4038 | 1.02 |

This change is also helpful for Tensor parallelism (PR #19378), specifically for gpt-oss, which uses the stream-k path.

Tensor Parallelism Performance on 2x RTX Pro 6000 Blackwell with PR #19378

All rows: 2x NVIDIA RTX PRO 6000 Blackwell Server Edition.

| model_type | n_prompt | n_gen | n_depth | 6e31365 avg_ts | PR avg_ts | Speed-up |
|---|---:|---:|---:|---:|---:|---:|
| gpt-oss 20B MXFP4 MoE | 512 | 0 | 0 | 14856.48 | 14887.14 | 1.00 |
| gpt-oss 20B MXFP4 MoE | 0 | 128 | 0 | 302.3649 | 305.1821 | 1.01 |
| gpt-oss 20B MXFP4 MoE | 512 | 0 | 4192 | 13864.22 | 13865.99 | 1.00 |
| gpt-oss 20B MXFP4 MoE | 0 | 128 | 4192 | 294.4124 | 296.2428 | 1.01 |
| gpt-oss 20B MXFP4 MoE | 512 | 0 | 8192 | 13231.02 | 13217.55 | 1.00 |
| gpt-oss 20B MXFP4 MoE | 0 | 128 | 8192 | 272.3426 | 282.8567 | 1.04 |
| gpt-oss 20B MXFP4 MoE | 512 | 0 | 16384 | 12030.71 | 12042.74 | 1.00 |
| gpt-oss 20B MXFP4 MoE | 0 | 128 | 16384 | 243.0565 | 273.1198 | 1.12 |
| gpt-oss 120B MXFP4 MoE | 512 | 0 | 0 | 7507.403 | 7558.003 | 1.01 |
| gpt-oss 120B MXFP4 MoE | 0 | 128 | 0 | 207.4498 | 209.0234 | 1.01 |
| gpt-oss 120B MXFP4 MoE | 512 | 0 | 4192 | 7235.533 | 7318.898 | 1.01 |
| gpt-oss 120B MXFP4 MoE | 0 | 128 | 4192 | 200.5357 | 204.1091 | 1.02 |
| gpt-oss 120B MXFP4 MoE | 512 | 0 | 8192 | 7000.607 | 7158.927 | 1.02 |
| gpt-oss 120B MXFP4 MoE | 0 | 128 | 8192 | 186.6602 | 197.5134 | 1.06 |
| gpt-oss 120B MXFP4 MoE | 512 | 0 | 16384 | 6602.919 | 6731.021 | 1.02 |
| gpt-oss 120B MXFP4 MoE | 0 | 128 | 16384 | 166.7649 | 189.2567 | 1.13 |
| llama 8B Q4_0 | 512 | 0 | 0 | 16012.23 | 16177.25 | 1.01 |
| llama 8B Q4_0 | 0 | 128 | 0 | 253.2106 | 258.8453 | 1.02 |
| llama 8B Q4_0 | 512 | 0 | 4192 | 13688.02 | 14059.05 | 1.03 |
| llama 8B Q4_0 | 0 | 128 | 4192 | 241.1187 | 245.5425 | 1.02 |
| llama 8B Q4_0 | 512 | 0 | 8192 | 12542.14 | 12924.33 | 1.03 |
| llama 8B Q4_0 | 0 | 128 | 8192 | 232.2564 | 234.3304 | 1.01 |
| llama 8B Q4_0 | 512 | 0 | 16384 | 10922.52 | 11332.17 | 1.04 |
| llama 8B Q4_0 | 0 | 128 | 16384 | 213.3187 | 215.0938 | 1.01 |
| llama 70B Q4_0 | 512 | 0 | 0 | 2858.523 | 2867.346 | 1.00 |
| llama 70B Q4_0 | 0 | 128 | 0 | 51.00086 | 51.25426 | 1.00 |
| llama 70B Q4_0 | 512 | 0 | 4192 | 2642.947 | 2650.761 | 1.00 |
| llama 70B Q4_0 | 0 | 128 | 4192 | 49.19993 | 49.21726 | 1.00 |
| llama 70B Q4_0 | 512 | 0 | 8192 | 2531.469 | 2527.46 | 1.00 |
| llama 70B Q4_0 | 0 | 128 | 8192 | 44.20208 | 44.22822 | 1.00 |
| llama 70B Q4_0 | 512 | 0 | 16384 | 2289.982 | 2287.164 | 1.00 |
| llama 70B Q4_0 | 0 | 128 | 16384 | 42.43172 | 42.45479 | 1.00 |

@gaugarg-nv gaugarg-nv requested a review from a team as a code owner March 27, 2026 22:14
@github-actions bot added the "Nvidia GPU" (Issues specific to Nvidia GPUs) and "ggml" (changes relating to the ggml tensor library for machine learning) labels Mar 27, 2026
@IMbackK previously approved these changes Mar 27, 2026, then dismissed their stale review March 27, 2026 22:21: "Missclick"

@JohannesGaessler (Contributor)

I don't think this is the correct logic. Asymptotically for an infinitely deep context it should always be worthwhile to run as many CUDA blocks as possible since the overhead becomes negligible. Intuitively I would expect something like this to be a better solution: if possible, always run at least 2 blocks / SM in order to keep the GPU busy when calling __syncthreads(). After that, only increase the number of blocks / SM while each block gets some minimum chunk of the KV cache to work on.

@gaugarg-nv (Contributor, Author)

> I don't think this is the correct logic. Asymptotically for an infinitely deep context it should always be worthwhile to run as many CUDA blocks as possible since the overhead becomes negligible. Intuitively I would expect something like this to be a better solution: if possible, always run at least 2 blocks / SM in order to keep the GPU busy when calling __syncthreads(). After that, only increase the number of blocks / SM while each block gets some minimum chunk of the KV cache to work on.

I agree with what you are saying. But in practice, we are seeing a good speed-up even for models that have max_blocks_per_sm = 2, where this PR reduces it to 1 block per SM. It's possible that the perf benefits diminish as the context length keeps increasing; I will try to collect that data.

I will also explore if there is a better way to reduce the overhead of flash_attn_stream_k_fixup without reducing blocks per SM.

@JohannesGaessler (Contributor)

An RTX 5090 and an RTX Pro 6000 both have 192 SMs; with Qwen 3 30B-A3B and its 4 KV heads, that will be on average a ~171 chunk / SM. The kernel will run with an internal batch size (nbatch_fa) of either 64 or 128 depending on Q->ne[1]. It may make sense to try increasing the minimum chunk size / SM to something like 256.

As for the stream-k fixup itself: if the occupancy is low anyways we could maybe run the kernel with a number of CUDA blocks that is an exact multiple of ntiles_dst and write an alternative fixup kernel with simplified logic that will maybe be faster.

@gaugarg-nv (Contributor, Author)

> An RTX 5090 and an RTX Pro 6000 both have 192 SMs; with Qwen 3 30B-A3B and its 4 KV heads, that will be on average a ~171 chunk / SM. The kernel will run with an internal batch size (nbatch_fa) of either 64 or 128 depending on Q->ne[1]. It may make sense to try increasing the minimum chunk size / SM to something like 256.

I'm not sure this will help in all cases. As you can see in the perf data, this change helps even at a 32K context length, where ntiles_KV*ntiles_dst is much bigger than max_blocks; in such cases, the chunk size per SM should already be large enough.

> As for the stream-k fixup itself: if the occupancy is low anyways we could maybe run the kernel with a number of CUDA blocks that is an exact multiple of ntiles_dst and write an alternative fixup kernel with simplified logic that will maybe be faster.

Thanks for the idea. I will look into it.

@gaugarg-nv (Contributor, Author)

Closing this PR in favor of #21159

@gaugarg-nv gaugarg-nv closed this Mar 29, 2026